# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Dict, List

import numpy as np
from torch import nn

from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _ICEVISION_GREATER_EQUAL_0_11_0, requires

if _ICEVISION_AVAILABLE:
    from icevision.core import tasks
    from icevision.core.bbox import BBox
    from icevision.core.keypoints import KeyPoints
    from icevision.core.mask import Mask, MaskArray
    from icevision.core.record import BaseRecord
    from icevision.core.record_components import (
        BBoxesRecordComponent,
        ClassMapRecordComponent,
        FilepathRecordComponent,
        ImageRecordComponent,
        InstancesLabelsRecordComponent,
        KeyPointsRecordComponent,
        RecordIDRecordComponent,
    )
    from icevision.data.prediction import Prediction
    from icevision.tfms import A
else:
    MaskArray = object

if _ICEVISION_AVAILABLE and _ICEVISION_GREATER_EQUAL_0_11_0:
    from icevision.core.record_components import InstanceMasksRecordComponent
elif _ICEVISION_AVAILABLE:
    from icevision.core.record_components import MasksRecordComponent


def _split_mask_array(mask_array: MaskArray) -> List[MaskArray]:
    """Break one stacked ``MaskArray`` into a list of ``MaskArray`` objects.

    Every entry of ``mask_array.data`` (iterated along its first axis) is wrapped in its own ``MaskArray``, so the
    output holds one ``MaskArray`` per mask, in the original order.
    """
    return list(map(MaskArray, mask_array.data))


def to_icevision_record(sample: Dict[str, Any]):
    """Convert a Flash sample dictionary into an IceVision ``BaseRecord``.

    Args:
        sample: A Flash sample.
            ``sample[DataKeys.INPUT]`` is either a file path (``str``) or an image.
            The optional ``sample[DataKeys.METADATA]`` may contain ``"image_id"``, ``"class_map"`` and
            ``"filepath"``.
            The optional ``sample[DataKeys.TARGET]`` may contain:
              - ``"labels"``;
              - ``"bboxes"``: dicts with ``"xmin"``, ``"ymin"``, ``"width"`` and ``"height"``;
              - masks: ``"masks"`` for IceVision >= 0.11, ``"mask_array"`` otherwise;
              - ``"keypoints"``, which requires ``"keypoints_metadata"`` alongside it.

    Returns:
        A ``BaseRecord`` carrying one component per piece of information found in ``sample``.
    """
    record = BaseRecord([])

    metadata = sample.get(DataKeys.METADATA, None) or {}

    if "image_id" in metadata:
        # Set the ID on the record itself. A standalone ``RecordIDRecordComponent`` that is never added to
        # ``record`` has no effect, which previously caused the image ID to be silently dropped.
        record.set_record_id(metadata["image_id"])

    class_map_component = ClassMapRecordComponent(tasks.detection)
    class_map_component.set_class_map(metadata.get("class_map", None))
    record.add_component(class_map_component)

    if isinstance(sample[DataKeys.INPUT], str):
        input_component = FilepathRecordComponent()
        input_component.set_filepath(sample[DataKeys.INPUT])
    else:
        if "filepath" in metadata:
            input_component = FilepathRecordComponent()
            input_component.filepath = metadata["filepath"]
        else:
            input_component = ImageRecordComponent()
        # ``set_img`` relies on the owning record (presumably to register the image size), so the composite is
        # attached before the image is set.
        input_component.composite = record
        input_component.set_img(sample[DataKeys.INPUT])
    record.add_component(input_component)

    if DataKeys.TARGET in sample:
        target = sample[DataKeys.TARGET]

        # Stay ``None`` when the target has no labels / boxes, so the negative-example handling below can safely
        # skip resetting them instead of hitting an unbound local.
        labels_component = None
        bboxes_component = None

        if "labels" in target:
            labels_component = InstancesLabelsRecordComponent()
            labels_component.add_labels_by_id(target["labels"])
            record.add_component(labels_component)

        if "bboxes" in target:
            bboxes = [
                BBox.from_xywh(bbox["xmin"], bbox["ymin"], bbox["width"], bbox["height"]) for bbox in target["bboxes"]
            ]
            bboxes_component = BBoxesRecordComponent()
            bboxes_component.set_bboxes(bboxes)
            record.add_component(bboxes_component)

        if _ICEVISION_GREATER_EQUAL_0_11_0:
            masks = target.get("masks", None)

            if masks is not None:
                masks_component = InstanceMasksRecordComponent()

                if len(masks) > 0 and isinstance(masks[0], Mask):
                    masks_component.set_masks(masks)
                else:
                    # TODO: This treats invalid examples as negative examples
                    if len(masks) == 0 or not (
                        len(masks) == len(record.detection.bboxes) == len(record.detection.label_ids)
                    ):
                        # Empty or inconsistent masks: store an empty (0, H, W) mask array and clear any labels /
                        # boxes so the record stays internally consistent.
                        data = np.zeros((0, record.height, record.width), np.uint8)
                        if labels_component is not None:
                            labels_component.label_ids = []
                        if bboxes_component is not None:
                            bboxes_component.bboxes = []
                    else:
                        data = np.stack(masks, axis=0)
                    mask_array = MaskArray(data)
                    masks_component.set_mask_array(mask_array)
                    masks_component.set_masks(_split_mask_array(mask_array))

                record.add_component(masks_component)
        else:
            mask_array = target.get("mask_array", None)
            if mask_array is not None:
                masks_component = MasksRecordComponent()
                masks_component.set_masks(mask_array)
                record.add_component(masks_component)

        if "keypoints" in target:
            keypoints = []

            for keypoints_list, keypoints_metadata in zip(target["keypoints"], target["keypoints_metadata"]):
                # IceVision expects a flat [x0, y0, v0, x1, y1, v1, ...] sequence.
                xyv = []
                for keypoint in keypoints_list:
                    xyv.extend((keypoint["x"], keypoint["y"], keypoint["visible"]))

                keypoints.append(KeyPoints.from_xyv(xyv, keypoints_metadata))
            keypoints_component = KeyPointsRecordComponent()
            keypoints_component.set_keypoints(keypoints)
            record.add_component(keypoints_component)

    return record


def from_icevision_detection(record: "BaseRecord"):
    """Convert the ``detection`` part of an IceVision record (or prediction) into a Flash target dictionary.

    Args:
        record: An IceVision record whose ``detection`` attribute holds the annotations to convert. ``record.height``
            and ``record.width`` are used when masks have to be rasterised.

    Returns:
        A dictionary containing whichever of these are available on ``record.detection``:
          - ``"bboxes"``: dicts with ``"xmin"``, ``"ymin"``, ``"width"`` and ``"height"``;
          - ``"masks"``;
          - ``"keypoints"`` together with ``"keypoints_metadata"``;
          - ``"labels"``;
          - ``"scores"``.
    """
    detection = record.detection

    result = {}

    if hasattr(detection, "bboxes"):
        result["bboxes"] = [
            {
                "xmin": bbox.xmin,
                "ymin": bbox.ymin,
                "width": bbox.width,
                "height": bbox.height,
            }
            for bbox in detection.bboxes
        ]

    masks = getattr(detection, "masks", None)
    mask_array = getattr(detection, "mask_array", None)
    # On older IceVision versions only ``masks`` is available and has to be rasterised into a ``MaskArray``. Only do
    # that when there is something to convert: with neither attribute present (plain box detection) the conversion
    # would call ``MaskArray.from_masks(None, ...)`` and fail.
    if mask_array is not None or (masks is not None and not _ICEVISION_GREATER_EQUAL_0_11_0):
        if not isinstance(mask_array, MaskArray) or len(mask_array.data) == 0:
            mask_array = MaskArray.from_masks(masks, record.height, record.width)

        result["masks"] = [mask.data[0] for mask in _split_mask_array(mask_array)]
    elif masks is not None:
        result["masks"] = masks  # Note - this doesn't unpack IceVision objects

    if hasattr(detection, "keypoints"):
        keypoints = detection.keypoints

        result["keypoints"] = []
        result["keypoints_metadata"] = []

        for keypoint in keypoints:
            keypoints_list = [{"x": x, "y": y, "visible": v} for x, y, v in keypoint.xyv]
            result["keypoints"].append(keypoints_list)

            # TODO: Unpack keypoints_metadata
            result["keypoints_metadata"].append(keypoint.metadata)

    if getattr(detection, "label_ids", None) is not None:
        result["labels"] = list(detection.label_ids)

    if getattr(detection, "scores", None) is not None:
        result["scores"] = list(detection.scores)

    return result


def from_icevision_record(record: "BaseRecord"):
    """Convert an IceVision ``BaseRecord`` back into a Flash sample dictionary.

    Args:
        record: The IceVision record to convert.

    Returns:
        A sample with:
          - ``DataKeys.METADATA``: always ``"size"`` as ``(height, width)``; also ``"image_id"``, ``"filepath"``
            and ``"class_map"`` when the record provides them;
          - ``DataKeys.INPUT``: the in-memory image, or the file path when there is no image;
          - ``DataKeys.TARGET``: built by ``from_icevision_detection``.
    """
    metadata = {"size": (record.height, record.width)}
    sample = {DataKeys.METADATA: metadata}

    record_id = getattr(record, "record_id", None)
    if record_id is not None:
        metadata["image_id"] = record_id

    # Looked up once and reused; previously the filepath metadata was written a second time in the image branch.
    filepath = getattr(record, "filepath", None)
    if filepath is not None:
        metadata["filepath"] = filepath

    if record.img is not None:
        sample[DataKeys.INPUT] = record.img
    elif filepath is not None:
        sample[DataKeys.INPUT] = filepath

    sample[DataKeys.TARGET] = from_icevision_detection(record)

    class_map = getattr(record.detection, "class_map", None)
    if class_map is not None:
        metadata["class_map"] = class_map

    return sample


def from_icevision_predictions(predictions: List["Prediction"]):
    """Turn IceVision predictions into Flash target dictionaries.

    Args:
        predictions: IceVision ``Prediction`` objects; the ``pred`` record of each one is converted.

    Returns:
        A list with one ``from_icevision_detection`` result per prediction, in input order.
    """
    return [from_icevision_detection(single.pred) for single in predictions]


class IceVisionTransformAdapter(nn.Module):
    """Apply a list of IceVision transforms to Flash samples.

    Each incoming sample is converted into an IceVision record, run through ``A.Adapter(transform)`` and converted
    back into a Flash sample.

    Args:
        transform: list of transformation functions to apply

    """

    def __init__(self, transform: List[Callable]):
        super().__init__()
        self.transform = A.Adapter(transform)

    def forward(self, x):
        """Transform the Flash sample ``x`` and return the resulting Flash sample."""
        transformed_record = self.transform(to_icevision_record(x))
        return from_icevision_record(transformed_record)


@dataclass
class IceVisionInputTransform(InputTransform):
    """``InputTransform`` built on IceVision transforms, sized by ``image_size``."""

    # Size handed to the IceVision resize / augmentation helpers.
    image_size: int = 128

    @requires(["image", "icevision"])
    def per_sample_transform(self):
        """Per-sample transform: resize-and-pad to ``image_size``, then normalize."""
        steps = list(A.resize_and_pad(self.image_size))
        steps.append(A.Normalize())
        return IceVisionTransformAdapter(steps)

    @requires(["image", "icevision"])
    def train_per_sample_transform(self):
        """Training per-sample transform: IceVision augmentations for ``image_size``, then normalize."""
        steps = list(A.aug_tfms(size=self.image_size))
        steps.append(A.Normalize())
        return IceVisionTransformAdapter(steps)

    def collate(self) -> Callable:
        """Use ``self._identity`` as the collate function."""
        return self._identity
